package gsssh

import (
	"bufio"
	"errors"
	"gitee.com/Sxiaobai/gs/gstool"
	"golang.org/x/crypto/ssh"
	"io"
	"regexp"
	"strings"
	"sync"
	"time"
)

// Terminal run states stored in Terminal.runStatus.
const (
	StatusWait    = 0 // terminal not started yet; the next command starts it
	StatusRunning = 1 // shell is up and accepting commands
	StatusError   = 2 // start-up failed or the connection broke
	StatusStop    = 3 // closed deliberately via CloseTerminal
)

const (
	// EndCommand is echoed after every command; seeing it in the output marks
	// the command as finished.
	EndCommand = `THIS_IS_END_POINT`

	// SshBroken is injected into the output stream when the SSH connection breaks.
	SshBroken = `notice : ssh connection is broken`
)

// Terminal holds the state of the interactive shell session that commands are
// run through (see RunCommandWait).
type Terminal struct {
	// Channels
	chanExit       chan struct{} // exit signal; not read or written anywhere in this file
	chanCommand    chan string   // commands waiting to be written to the shell's stdin
	chanReceiveMsg chan string   // lines read from the shell's stdout/stderr
	// Interactive SSH session the shell runs in
	session *ssh.Session
	// Locks
	lockCommand sync.Mutex // serializes RunCommandWait calls
	lockSocket  sync.Mutex // socket push lock; not used anywhere in this file
	// Wait groups
	waitPty     sync.WaitGroup // checkAndRunTerminal waits on it until startTerminal signals start-up
	waitCommand sync.WaitGroup // RunCommandWait waits on it until combineMsg has stored the command's result
	// State
	runStatus int   // StatusWait / StatusRunning / StatusError / StatusStop
	runErr    error // error recorded by setError
	// Accumulated output, used to detect when a command has finished
	runResult               string // output of the most recently completed command
	runReceiveMsg           string // output received since the last completed command; searched for EndCommand
	runCommand              string // command currently executing ("" when none is pending)
	runHostName             string // host name (hostname); not used anywhere in this file
	runWorkDir              string // home directory of the SSH login user (echo $HOME); not used anywhere in this file
	runEndCommand           string // shell snippet run after every command; prints something like "user@host: /current/dir$"
	runCombineNum           int    // number of output lines batched per stream push (default 2)
	runCloseFirstReceiveMsg bool   // strip the prompt prefix from streamed output instead of re-inserting the executed command, i.e. hide lines like "xxx@iZbp18rsv13t3c3a1hzqikZ: /var/www$ cd /var/www"; best combined with funcBeforeCommand
	// Callbacks
	funcBeforeCommand   func(command string) string // called before a command is sent; a non-empty return is streamed with {user} replaced by the user name
	funcReceiveMsg      func(msg string) string     // called for every received line; its return value replaces the line
	funcStreamReceive   func(msg string)            // receives the streamed output (e.g. for socket / SSE push)
	funcBroken          func()                      // called when the terminal is found in the error state after a command
	funcStartRunCommand func()                      // called when a command starts
	funcEndRunCommand   func()                      // called when a command ends
	// Markers that end a command even though EndCommand may never be printed
	exceptionList []string
}

// SetCombineNum sets how many received output lines are batched together
// before the batch is pushed to the stream callback. A batch is also pushed
// early when it contains EndCommand or an exception marker.
func (h *SshConfig) SetCombineNum(combineNum int) {
	h.terminal.runCombineNum = combineNum
}

// CloseFirstReceiveMsg changes how the prompt prefix ("user@host...$ ") is
// rendered in streamed output: it is removed instead of being followed by
// the executed command. Best combined with SetFuncBefore.
func (h *SshConfig) CloseFirstReceiveMsg() {
	h.terminal.runCloseFirstReceiveMsg = true
}

// SetFuncBefore registers a hook that RunCommandWait calls with the command
// just before sending it. If the hook returns a non-empty string, that string
// is streamed with every "{user}" placeholder replaced by the login user name.
func (h *SshConfig) SetFuncBefore(before func(string) string) {
	h.terminal.funcBeforeCommand = before
}

// SetFuncStartCommand registers a hook that RunCommandWait calls once the
// terminal is running, before the command is sent.
func (h *SshConfig) SetFuncStartCommand(f func()) {
	h.terminal.funcStartRunCommand = f
}

// SetFuncEndCommand registers a hook that RunCommandWait calls after it has
// stopped waiting for the command's output.
func (h *SshConfig) SetFuncEndCommand(f func()) {
	h.terminal.funcEndRunCommand = f
}

// SetFuncBroken registers a hook that RunCommandWait calls when the terminal
// is in the error state after a command; the connection is closed right after.
func (h *SshConfig) SetFuncBroken(broken func()) {
	h.terminal.funcBroken = broken
}

// SetFuncReceiveMsg registers a filter applied to every line received from
// the shell; the returned string replaces the line in both the streamed output
// and the accumulated command result.
func (h *SshConfig) SetFuncReceiveMsg(receive func(string) string) {
	h.terminal.funcReceiveMsg = receive
}

// SetFuncStreamReceive registers the callback that receives streamed output
// (batched command output and notices), e.g. to push it over a socket or SSE.
func (h *SshConfig) SetFuncStreamReceive(receive func(string)) {
	h.terminal.funcStreamReceive = receive
}

// RunCommandWait runs command through the shared interactive terminal and
// blocks until its output is complete, then returns the collected output.
//
// Calls are serialized by lockCommand, and the terminal is started lazily on
// first use. The wait ends when combineMsg sees EndCommand or an exception
// marker in the output, or when the RunTimeout timer injects EndCommand.
//
// It returns an error when the terminal has been stopped, is in the error
// state, or fails to start. If the terminal is in the error state once the
// command has finished, the output collected so far is still returned with a
// nil error; when funcBroken is set, EndCommand is pushed, funcBroken is
// called and h.Close() is invoked.
func (h *SshConfig) RunCommandWait(command string) (string, error) {
	h.terminal.lockCommand.Lock()
	defer h.terminal.lockCommand.Unlock()

	switch h.terminal.runStatus {
	case StatusStop:
		return ``, errors.New(`连接已停止`)
	case StatusError:
		return ``, errors.New(`连接已中断，等待重连`)
	}
	if err := h.checkAndRunTerminal(); err != nil {
		return ``, err
	}
	if h.terminal.funcStartRunCommand != nil {
		h.terminal.funcStartRunCommand()
	}

	h.terminal.waitCommand.Add(1)
	if h.terminal.funcBeforeCommand != nil {
		if banner := h.terminal.funcBeforeCommand(command); banner != `` {
			h.streamReceive(gstool.SReplaces(banner, map[string]string{
				`{user}`: h.UserName,
			}))
		}
	}
	h.terminal.chanCommand <- command
	h.RunTimeout()
	h.terminal.waitCommand.Wait()
	// runTimeoutTicker may still be nil here (nothing guarantees RunTimeout
	// has installed one), so guard before stopping it.
	if h.runTimeoutTicker != nil {
		h.runTimeoutTicker.Stop()
	}
	if h.terminal.funcEndRunCommand != nil {
		h.terminal.funcEndRunCommand()
	}
	result := h.terminal.runResult
	if h.terminal.runStatus == StatusError && h.terminal.funcBroken != nil {
		// Best effort: the receive channel may already be closed, in which
		// case toChanReceiveMsg reports an error that there is nothing to do about.
		_ = h.toChanReceiveMsg(EndCommand)
		h.terminal.funcBroken()
		h.Close()
	}
	return result, nil
}

// RunTimeout arms the timeout for the command that is currently executing.
//
// Any previous timeout ticker is stopped. The new ticker is stored in
// h.runTimeoutTicker before this method returns, so the caller can Stop it as
// soon as the command finishes. If the timeout expires first, a notice is
// pushed to the stream callback and EndCommand is injected into the output
// stream so the waiting RunCommandWait returns.
//
// A non-positive MaxRunSecond disables the timeout; time.NewTicker would
// panic on such an interval.
func (h *SshConfig) RunTimeout() {
	if h.runTimeoutTicker != nil {
		h.runTimeoutTicker.Stop()
	}
	timeout := time.Duration(h.MaxRunSecond) * time.Second
	if timeout <= 0 {
		return
	}
	// Created here rather than inside the goroutine, so the caller never sees
	// a nil or stale ticker when it stops it.
	ticker := time.NewTicker(timeout)
	h.runTimeoutTicker = ticker
	go func() {
		// Stop never closes ticker.C, so a goroutine waiting only on it would
		// block forever once the command finishes in time. The backstop timer
		// bounds this goroutine's lifetime to roughly one timeout period.
		backstop := time.NewTimer(timeout + time.Second)
		defer backstop.Stop()
		select {
		case <-ticker.C:
		case <-backstop.C:
			// Both may have become ready together; give up only if the
			// ticker really did not fire (i.e. it was stopped in time).
			select {
			case <-ticker.C:
			default:
				return
			}
		}
		ticker.Stop()
		h.streamReceive(`注意：执行超时，本次执行返回`)
		_ = h.toChanReceiveMsg(EndCommand)
	}()
}

// checkAndRunTerminal lazily starts the interactive shell. Only while the
// terminal is still in StatusWait does it launch startTerminal in a goroutine
// and block on waitPty until startTerminal signals it. It then returns the
// error recorded during start-up, if any. In every other state it returns nil
// without doing anything.
func (h *SshConfig) checkAndRunTerminal() error {
	if h.terminal.runStatus != StatusWait {
		return nil
	}
	h.terminal.waitPty.Add(1)
	go h.startTerminal()
	h.terminal.waitPty.Wait()
	return h.terminal.runErr
}

// startTerminal opens an interactive shell on h.client, connecting first if
// there is no client yet. It requests a PTY, wires up stdin/stdout/stderr,
// and starts the receiveMsg, receiveCommand and combineMsg goroutines. It
// then blocks until the remote shell exits.
//
// The caller must have called h.terminal.waitPty.Add(1) (see
// checkAndRunTerminal). startTerminal calls waitPty.Done exactly once:
// right after the shell is running, or, if start-up fails, on return after
// the failure has been recorded with setError so the waiter can read runErr.
//
// If the shell exits with an error, the following happens in order:
//   - the terminal is marked as failed;
//   - SshBroken and EndCommand are injected into the output stream, so a
//     pending RunCommandWait returns;
//   - h.Close() is called after a one-second pause.
func (h *SshConfig) startTerminal() {
	ptyReleased := false
	releasePty := func() {
		if ptyReleased {
			return
		}
		ptyReleased = true
		h.terminal.waitPty.Done()
	}
	// Deferred first so it runs last, after the session-closing defer below
	// (which may still call setError). Without it, every early return would
	// leave the waiter blocked in waitPty.Wait forever, while it still holds
	// lockCommand.
	defer releasePty()

	h.RunType = RunTypeTerminal
	if h.terminal.runStatus == StatusRunning {
		h.setError(errors.New(`正在运行中`))
		return
	}
	if h.terminal.runStatus == StatusError {
		h.terminal.runStatus = StatusWait
		h.terminal.runErr = nil
	}
	if h.client == nil {
		if clientErr := h.ConnectAuthPassword(); clientErr != nil {
			h.setError(errors.New(`初始化client失败`))
			return
		}
	}
	session, sessionErr := h.client.NewSession()
	h.terminal.session = session
	if sessionErr != nil {
		h.setError(sessionErr)
		return
	}
	defer func() {
		if h.terminal.session != nil {
			if sessionCloseErr := h.terminal.session.Close(); sessionCloseErr != nil {
				h.setError(sessionCloseErr)
			}
		}
	}()
	// Request the PTY with echo turned off (ssh.ECHO: 0).
	modes := ssh.TerminalModes{
		ssh.ECHO:          0,
		ssh.TTY_OP_ISPEED: 14400,
		ssh.TTY_OP_OSPEED: 14400,
	}
	if ptyErr := session.RequestPty("linux", 32, 160, modes); ptyErr != nil {
		h.setError(ptyErr)
		return
	}

	// Pipes must be obtained before Shell() starts the remote shell.
	stdout, stdoutErr := session.StdoutPipe()
	if stdoutErr != nil {
		h.setError(stdoutErr)
		return
	}
	stderr, stderrErr := session.StderrPipe()
	if stderrErr != nil {
		h.setError(stderrErr)
		return
	}
	stdin, stdinErr := session.StdinPipe()
	if stdinErr != nil {
		h.setError(stdinErr)
		return
	}
	if shellErr := session.Shell(); shellErr != nil {
		h.setError(shellErr)
		return
	}

	h.initParams()
	go h.receiveMsg(stdout, stderr)
	go h.receiveCommand(stdin)
	go h.combineMsg()
	h.terminal.runStatus = StatusRunning
	releasePty()

	// Wait on the local session value: CloseTerminal may set the field to nil
	// concurrently once the waiter has been released.
	if waitErr := session.Wait(); waitErr != nil {
		h.setError(waitErr)
		// Emit the end signals so a pending command stops waiting. These
		// fail harmlessly if the channels were already closed.
		_ = h.toChanReceiveMsg(SshBroken)
		_ = h.toChanReceiveMsg(EndCommand)
		time.Sleep(time.Second)
		// Wait for the reconnect callback.
		h.Close()
	}
}

// toChanReceiveMsg pushes msg into chanReceiveMsg. The channel is created
// unbuffered in initParams, so this blocks until combineMsg receives the
// message. Sending on a closed channel panics; that panic is recovered,
// logged, and reported as the returned error instead.
func (h *SshConfig) toChanReceiveMsg(msg string) (err error) {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		h.Errof(`尝试写入msg：%s失败 %v`, msg, r)
		err = gstool.Error(`尝试写入msg：%s失败 %v`, msg, r)
	}()
	h.terminal.chanReceiveMsg <- msg
	return nil
}

// initParams (re)initializes the per-session terminal state. startTerminal
// calls it after the remote shell has started and before the I/O goroutines
// are launched.
func (h *SshConfig) initParams() {
	// Buffer of 1: commands are executed one at a time.
	h.terminal.chanCommand = make(chan string, 1)
	h.terminal.chanReceiveMsg = make(chan string)
	// Run after every command. It prints something like
	// "xxx@iZbp18rsv13t3c3a1hzqikZ: /var/www/yiishell$". Unlike a normal
	// prompt, there is a space after the colon and the path is absolute.
	h.terminal.runEndCommand = strings.ReplaceAll(`echo "$(printf '{user_name}@%s ' $(hostname):)$(pwd)"$`, `{user_name}`, h.UserName)
	// Default to 2 lines per batch. A positive value set via SetCombineNum
	// before the terminal started is kept instead of being overwritten.
	if h.terminal.runCombineNum <= 0 {
		h.terminal.runCombineNum = 2
	}
	// Drop output left over from a previous session. Stale EndCommand or
	// SshBroken markers from a broken connection would otherwise make the
	// first command after a restart complete immediately with a wrong result.
	h.terminal.runReceiveMsg = ``
	// These markers may not be followed by EndCommand, so combineMsg treats
	// them as the end of a command on their own.
	h.terminal.exceptionList = []string{`-bash: syntax error`, SshBroken}
}

// receiveCommand takes commands from chanCommand and writes them to the
// shell's stdin until the channel is closed. Each command runs in a subshell,
// followed by runEndCommand and an echo of EndCommand, so combineMsg can tell
// where the command's output ends. If a write fails, it injects SshBroken and
// EndCommand into the output stream, records the error and returns.
func (h *SshConfig) receiveCommand(stdin io.WriteCloser) {
	for {
		command, open := <-h.terminal.chanCommand
		if !open {
			return
		}
		line := "(" + command + ") ; " + h.terminal.runEndCommand + ";echo " + EndCommand + " \n"
		h.terminal.runCommand = command
		if _, err := stdin.Write([]byte(line)); err != nil {
			_ = h.toChanReceiveMsg(SshBroken)
			_ = h.toChanReceiveMsg(EndCommand)
			h.setError(err)
			return
		}
	}
}

// receiveMsg starts a goroutine that reads the shell's stdout and then stderr
// line by line and pushes each line into chanReceiveMsg. Lines read while the
// terminal is still in StatusWait are discarded. The goroutine stops when the
// streams end, a read fails, or chanReceiveMsg has been closed.
//
// A bufio.Reader is used rather than bufio.Scanner. Scanner gives up on lines
// longer than 64 KiB (bufio.ErrTooLong), which would silently stop all further
// output from being delivered and leave every later command waiting for its
// timeout.
func (h *SshConfig) receiveMsg(std, stderr io.Reader) {
	go func() {
		reader := bufio.NewReader(io.MultiReader(std, stderr))
		for {
			line, readErr := reader.ReadString('\n')
			if readErr != nil && line == `` {
				return
			}
			// Same trimming as bufio.ScanLines: drop "\n" and one preceding "\r".
			line = strings.TrimSuffix(strings.TrimSuffix(line, "\n"), "\r")
			if h.terminal.runStatus != StatusWait {
				if err := h.toChanReceiveMsg(line); err != nil {
					return
				}
			}
			if readErr != nil {
				return
			}
		}
	}()
}

// combineMsg consumes the shell's output lines from chanReceiveMsg until that
// channel is closed. For every line it:
//   - removes the "\x1b[01;32m" color code and passes the line through
//     funcReceiveMsg (when set);
//   - adds the line to the current stream batch, which is pushed to the stream
//     callback once isStreamReceive says so;
//   - appends the line to runReceiveMsg. If a command is pending and the
//     buffer contains EndCommand or an exception marker, it stores the
//     command's output in runResult and releases waitCommand.
func (h *SshConfig) combineMsg() {
	combineMsg := ``
	combineNum := 0
	// The user name is literal text. QuoteMeta stops characters such as "."
	// or "+" in it from acting as regex syntax (or making MustCompile panic).
	userPattern := regexp.QuoteMeta(h.UserName)
	// Prompt prefix such as "user@host:~$ " (trailing space required).
	regReplace := regexp.MustCompile(`(` + userPattern + `[@].*[$]\ )`)
	// Same without the trailing space, e.g. the "user@host: /path$" line
	// printed by runEndCommand.
	regReplaceCustom := regexp.MustCompile(`(` + userPattern + `[@].*[$])`)
	for {
		msg, ok := <-h.terminal.chanReceiveMsg
		if !ok {
			return
		}
		// Remove the color code.
		msg = strings.ReplaceAll(msg, "\x1b[01;32m", "")
		if h.terminal.funcReceiveMsg != nil {
			msg = h.terminal.funcReceiveMsg(msg)
		}
		combineMsg += msg + "\n"
		// Append only the new line. Appending the whole batch (combineMsg)
		// would add lines already in the buffer a second time and duplicate
		// them in runResult.
		h.terminal.runReceiveMsg += msg + "\n"
		combineNum++

		if h.isStreamReceive(combineMsg, combineNum) {
			if h.terminal.runCloseFirstReceiveMsg {
				combineMsg = regReplace.ReplaceAllString(combineMsg, ``)
			} else {
				// Re-insert the executed command after the prompt. "$" is
				// doubled because ReplaceAllString expands $name / ${n} in the
				// replacement, which would mangle commands like "echo $HOME".
				echoed := strings.ReplaceAll(h.terminal.runCommand, `$`, `$$`)
				combineMsg = regReplace.ReplaceAllString(combineMsg, `${1}`+echoed+"\n")
			}
			// Remove the end marker from what the user sees.
			combineMsg = strings.ReplaceAll(combineMsg, EndCommand, ``)
			// A batch consisting only of a line break is not pushed.
			if combineMsg != "\n" {
				h.streamReceive(strings.TrimRight(combineMsg, "\n"))
			}
			combineMsg = ``
			combineNum = 0
		}

		// NOTE(review): both searches scan the whole buffer on every line,
		// which is quadratic in the size of a single command's output.
		findResultIndex := strings.Index(h.terminal.runReceiveMsg, EndCommand)
		// An exception marker may not be followed by EndCommand, so treat
		// everything received so far as the result.
		if gstool.SContains(h.terminal.runReceiveMsg, h.terminal.exceptionList) {
			findResultIndex = len(h.terminal.runReceiveMsg)
		}
		if findResultIndex != -1 && h.terminal.runCommand != `` {
			result := h.terminal.runReceiveMsg[:findResultIndex]
			// Prompt prefixes are not part of the command's output. Leaving
			// them in would break parsing, e.g. looking up a pid.
			result = regReplace.ReplaceAllString(result, ``)
			result = regReplaceCustom.ReplaceAllString(result, ``)
			h.terminal.runResult = result
			rest := h.terminal.runReceiveMsg[findResultIndex:]
			h.terminal.runReceiveMsg = strings.ReplaceAll(rest, EndCommand, ``)
			h.terminal.runCommand = ``
			h.terminal.waitCommand.Done()
		}
	}
}

// isStreamReceive reports whether the current output batch should be pushed
// now. This is the case when the batch contains EndCommand, has reached
// runCombineNum lines, or contains one of the exception markers.
func (h *SshConfig) isStreamReceive(combineMsg string, combineNum int) bool {
	switch {
	case strings.Contains(combineMsg, EndCommand):
		return true
	case combineNum == h.terminal.runCombineNum:
		return true
	default:
		return gstool.SContains(combineMsg, h.terminal.exceptionList)
	}
}

// streamReceive forwards msg to the registered stream callback, if one is set.
func (h *SshConfig) streamReceive(msg string) {
	push := h.terminal.funcStreamReceive
	if push == nil {
		return
	}
	push(msg)
}

// setError marks the terminal as failed (StatusError) and records err as the cause.
func (h *SshConfig) setError(err error) {
	h.terminal.runStatus, h.terminal.runErr = StatusError, err
}

// CloseTerminal shuts the terminal down deliberately. It:
//   - marks the state as StatusStop;
//   - closes the SSH session;
//   - closes chanCommand and chanReceiveMsg, which ends the receiveCommand
//     and combineMsg loops.
//
// It does nothing when there is no session or the terminal is still in
// StatusWait.
func (h *SshConfig) CloseTerminal() {
	if h.terminal.session == nil || h.terminal.runStatus == StatusWait {
		return
	}
	h.terminal.runStatus = StatusStop
	// Close can fail, e.g. when the remote side already closed the session;
	// there is nothing useful to do about it during shutdown.
	_ = h.terminal.session.Close()

	closeChan := func(ch chan string) {
		if ch == nil {
			return
		}
		// The channel may already have been closed by an earlier CloseTerminal.
		// This happens, for example, if a restarted session failed before
		// initParams created new channels. A double close panics, and that
		// must not crash the shutdown path.
		defer func() { _ = recover() }()
		close(ch)
	}
	closeChan(h.terminal.chanCommand)
	closeChan(h.terminal.chanReceiveMsg)
	h.terminal.session = nil
}
